#!/usr/bin/env python3
# converts a saved PyTorch model to ONNX format
import os
import sys
import argparse

import torch.onnx

from vision.ssd.mobilenetv1_ssd import create_mobilenetv1_ssd
from vision.ssd.config import mobilenetv1_ssd_config

# set the device (the export is performed on the CPU)
device = torch.device('cpu')
print(f"=> running on device {device}")

# path of the trained PyTorch checkpoint to convert
# (named model_path rather than `input` so the builtin is not shadowed)
model_path = "models/best.pth"

# fail early with a readable message instead of a traceback from net.load()
if not os.path.isfile(model_path):
    sys.exit(f"error: checkpoint not found: {model_path}")

# network hyper-parameters -- presumably these must match the values used when
# the checkpoint was trained (num_classes may include a background class);
# TODO confirm against the training label set
num_classes = 4
resolution = 300
net_name = 'ssd-mobilenet'

# construct the network architecture
print(f"=> creating network:  {net_name}")
print(f"=> num classes:       {num_classes}")
print(f"=> resolution:        {resolution}x{resolution}")

# keep the SSD config's image size in sync with `resolution`, so the network
# and the dummy input used for tracing below always agree
mobilenetv1_ssd_config.set_image_size(resolution)
net = create_mobilenetv1_ssd(num_classes, is_test=True)

# load the model checkpoint
print(f"=> loading checkpoint:  {model_path}")
net.load(model_path)
net.to(device)
net.eval()

# example image data used to trace the graph: (batch=1, channels=3, H, W),
# created on the same device as the network
dummy_input = torch.randn(1, 3, resolution, resolution, device=device)
output = 'mobilenet-ssd.onnx'

# export to ONNX; output_names are assigned positionally, so this assumes the
# network returns (scores, boxes) in that order when built with is_test=True
input_names = ['input_0']
output_names = ['scores', 'boxes']

print("=> exporting model to ONNX...")
torch.onnx.export(net, dummy_input, output, verbose=True, input_names=input_names, output_names=output_names)
print(f"model exported to:  {output}")
